import copy
import os
import numpy as np
from tqdm import tqdm
import pandas as pd
from agent import SportsAgent
from evaluate.evaluate_distrib_rl import empirical_action_output_with_features
from generic.data_util import load_config, divide_dataset_according2date, read_args
from generic.plot_util import plot_heatmap, plot_histogram, plot_scatter

# Action names for which event counts are collected and heat maps are plotted.
PLOT_ACTIONS = [
    'assist',
    'block',
    'carry',
    'check',
    'controlledbreakout',
    'controlledentryagainst',
    'dumpin',
    'dumpinagainst',
    'dumpout',
    'faceoff',
    'goal',
    'icing',
    'lpr',
    'offside',
    'pass',
    'pass1timer',
    'penalty',
    'pressure',
    'puckprotection',
    'reception',
    'receptionprevention',
    'shot',
    'shot1timer',
]


def calculate_action_count(all_files, target_actions, is_home, agent, mode=None, debug_mode=False):
    """Count per-action events for one data split and plot where they occur.

    ``all_files`` is divided into train/validate/test games by
    ``divide_dataset_according2date``; the split selected by ``mode`` is fed
    game by game through ``empirical_action_output_with_features``, which
    accumulates a list of feature rows per action. For every target action
    that was observed, the last two columns of its rows are treated as
    (x, y) coordinates, counted per square bin and drawn as a heat map.
    Finally the 'goal' coordinates are drawn as a scatter plot.

    Args:
        all_files: game file names to be split.
        target_actions: action names to count and plot.
        is_home: forwarded unchanged to ``empirical_action_output_with_features``.
        agent: object providing ``train_rate``, ``sports`` and
            ``apply_data_date_div``; also forwarded to the feature extractor.
        mode: which split to analyse: 'train', 'validate' or 'test'.
        debug_mode: currently unused; kept so existing callers keep working.

    Returns:
        dict mapping each observed target action to the number of rows
        accumulated over the first 100 games of the split (over all games
        when the split has fewer than 100).

    Raises:
        ValueError: if ``mode`` is not one of the three split names.
    """
    split_names = ('train', 'validate', 'test')
    # Validate up front: previously an unknown mode (including the default
    # None) only failed later with an unbound-variable error.
    if mode not in split_names:
        raise ValueError("mode must be one of {0}, got {1!r}".format(split_names, mode))

    training_games, validate_games, testing_games = divide_dataset_according2date(
        all_data_files=all_files,
        train_rate=agent.train_rate,
        sports=agent.sports,
        if_split=agent.apply_data_date_div,
    )
    games = {'train': training_games,
             'validate': validate_games,
             'test': testing_games}[mode]

    # Counts are snapshotted after a fixed number of games so that splits of
    # different sizes are compared on the same number of games. Splits with
    # fewer games fall back to a snapshot after their last game instead of
    # returning no counts at all.
    action_num_game_num = 100
    snapshot_after = min(action_num_game_num, len(games))

    action_outcome_dict = {}
    action_nums_dict = {}
    for game_idx, game_name in enumerate(tqdm(games)):
        action_outcome_dict = empirical_action_output_with_features(agent=agent,
                                                                    game_name=game_name,
                                                                    is_home=is_home,
                                                                    action_outcome_dict=action_outcome_dict)
        if game_idx + 1 == snapshot_after:
            action_nums_dict = _snapshot_action_counts(action_outcome_dict, target_actions)

    bin_size = 20
    x_min, x_max = -100, 100
    y_min, y_max = -50, 50
    for target_action in target_actions:
        if target_action not in action_outcome_dict:
            continue
        locations = np.asarray(action_outcome_dict[target_action])[:, -2:]
        bin_num_values = _count_events_per_bin(locations, bin_size, x_min, x_max, y_min, y_max)

        print(bin_num_values)
        # Only columns 5 and up are plotted, i.e. the bins with x >= 0.
        plot_heatmap(data_store=bin_num_values[:, 5:],
                     plot_name='/distribution_shift/heat_map_{0}_num_bin_{1}_{2}'.format(mode,
                                                                                         bin_size,
                                                                                         target_action))

    if 'goal' not in action_outcome_dict:
        # Nothing to scatter; previously this case raised a KeyError.
        print('no goal events found for mode {0}; skipping scatter plot'.format(mode))
        return action_nums_dict

    # The sliced array already is the N x 2 (x, y) array that used to be
    # rebuilt through zip(); on Python 3 np.asarray(zip(...)) wraps the lazy
    # iterator in a 0-d object array instead.
    goal_locations = np.asarray(action_outcome_dict['goal'])[:, -2:]
    plot_name = './heatmaps_for_calibration/distribution_shift/scatter_map_{0}_goal'.format(mode)
    plot_scatter(scatter_data=goal_locations,
                 labels=['goal'],
                 plot_name=plot_name)

    return action_nums_dict


def _snapshot_action_counts(action_outcome_dict, target_actions):
    """Return {action: number of accumulated rows} for the target actions seen so far."""
    return {action: len(action_outcome_dict[action])
            for action in target_actions
            if action in action_outcome_dict}


def _count_events_per_bin(locations, bin_size, x_min, x_max, y_min, y_max):
    """Count (x, y) points per square bin on a [y_dim, x_dim] grid.

    Row 0 holds the bins with the largest y values, so the grid is oriented
    like a picture. Points falling outside the grid are ignored.
    """
    x_dim = int((x_max - x_min) / bin_size)
    y_dim = int((y_max - y_min) / bin_size)
    counts = np.zeros([y_dim, x_dim])

    locations = np.asarray(locations, dtype=float)
    if len(locations) == 0:
        return counts

    x_offset = locations[:, 0] - x_min
    y_offset = locations[:, 1] - y_min
    col_idx = ((x_offset - x_offset % bin_size) / bin_size).astype(int)
    row_idx = y_dim - 1 - ((y_offset - y_offset % bin_size) / bin_size).astype(int)

    inside = (row_idx >= 0) & (row_idx < y_dim) & (col_idx >= 0) & (col_idx < x_dim)
    # np.add.at accumulates correctly when the same bin index occurs repeatedly.
    np.add.at(counts, (row_idx[inside], col_idx[inside]), 1)
    return counts


def train(args):
    """Compare per-action event counts across the train/validate/test splits.

    Despite its name this function does no model training: it builds a
    ``SportsAgent``, runs ``calculate_action_count`` (which also produces the
    per-action heat maps) once per split, and saves a grouped bar chart of the
    counts for the actions that were counted in all three splits.

    Args:
        args: parsed command-line arguments; passed to ``load_config`` and
            read for ``DEBUG_MODE``.
    """
    config, debug_mode, log_file_path = load_config(args)
    log_file = open(log_file_path, 'w') if log_file_path is not None else None

    try:
        # The command-line flag overrides the debug setting from load_config.
        debug_mode = bool(args.DEBUG_MODE)

        is_home = 1
        agent = SportsAgent(config=config, log_file=log_file)
        all_files = sorted(os.listdir(agent.train_data_path))

        modes = ['train', 'validate', 'test']
        action_nums_dict_all = {}
        for mode in modes:
            action_nums_dict_all[mode] = calculate_action_count(all_files=all_files,
                                                                target_actions=PLOT_ACTIONS,
                                                                is_home=is_home,
                                                                agent=agent,
                                                                mode=mode,
                                                                debug_mode=debug_mode)

        # Keep only actions counted in every split. This is built as a new
        # list: filtering PLOT_ACTIONS in place used to shrink the module
        # constant and thereby the actions plotted for the later splits.
        validate_action = [action for action in PLOT_ACTIONS
                           if all(action in action_nums_dict_all[mode] for mode in modes)]
        if not validate_action:
            print('no action was counted in all splits; skipping bar chart')
            return

        action_plot_nums_dict = {
            mode: [action_nums_dict_all[mode][action] for action in validate_action]
            for mode in modes
        }

        import matplotlib.pyplot as plt
        m1_t = pd.DataFrame(action_plot_nums_dict)
        width = .35  # width of a bar
        m1_t[modes].plot(kind='bar', width=width, figsize=(15, 15))

        plt.xlim([-width, len(m1_t['train']) - width])
        plt.xticks(range(len(validate_action)), validate_action, rotation=90, fontsize=16)
        bar_plot_label = './heatmaps_for_calibration/distribution_shift/hist_action_count.png'
        plt.savefig(bar_plot_label)
        # Release the figure once it has been written to disk.
        plt.close()
    finally:
        if log_file is not None:
            log_file.close()


if __name__ == "__main__":
    # Script entry point: run the analysis only when TRAIN_FLAG parses to a
    # non-zero integer.
    args = read_args()
    train_requested = bool(int(args.TRAIN_FLAG))
    if train_requested:
        train(args)
